import torch
import math

def UltraE_fval_obj(triple, Corrupted_triple, vec_entity, vec_relation, vec_bias, config_yaml):
    """Return the summed negative-sampling loss for a batch of triples.

    The positive score ``S`` and the corrupted scores ``CS`` are produced by
    ``UltraE_score`` from the distances returned by ``UltraE_get_all_distU``.
    The per-triple loss is
    ``-1 / n * (log sigmoid(S) + sum_k log(1 - sigmoid(CS_k)))``; NaN entries
    are replaced by 0 before everything is summed.

    Args:
        triple: Index tensor; columns 0, 1 and 2 select the head entity, the
            tail entity and the relation of each of the ``n`` rows.
        Corrupted_triple: 3-D index tensor; ``[:, :, 0]`` and ``[:, :, 1]``
            select corrupted head and tail entities. Axis 1 enumerates the
            ``k`` corrupted triples of each row.
        vec_entity: Entity table indexed by row.
        vec_relation: Relation table, one matrix per relation.
        vec_bias: Entity bias table. Assumed to hold one scalar per entity
            (shape ``(N,)`` or ``(N, 1)``) -- TODO confirm against the caller.
        config_yaml: Mapping providing ``["datafeature"]["margin"]`` here and
            the ``p`` / ``beta`` entries read by ``UltraE_get_all_distU``.

    Returns:
        A 0-dim tensor holding the batch loss.
    """
    margin = config_yaml["datafeature"]["margin"]
    n = triple.shape[0]
    k = Corrupted_triple.shape[1]

    eh = vec_entity[triple[:, 0], :].T
    et = vec_entity[triple[:, 1], :].T
    # Reverse the three axes explicitly: `.T` on a tensor that is not 2-D is
    # deprecated, and permute(2, 1, 0) yields the same view.
    Ceh = vec_entity[Corrupted_triple[:, :, 0], :].permute(2, 1, 0)
    Cet = vec_entity[Corrupted_triple[:, :, 1], :].permute(2, 1, 0)
    R = vec_relation[triple[:, 2], :, :]

    # Explicit shapes instead of `.T` followed by `.squeeze()`: squeeze also
    # removed the sample axis when k == 1, which made the sum over dim 0
    # below run over the batch instead of over the corrupted samples.
    bh = vec_bias[triple[:, 0]].reshape(n)
    bt = vec_bias[triple[:, 1]].reshape(n)
    Cbh = vec_bias[Corrupted_triple[:, :, 0]].reshape(n, k).t()
    Cbt = vec_bias[Corrupted_triple[:, :, 1]].reshape(n, k).t()

    distU, CdistU = UltraE_get_all_distU(eh, R, et, Ceh, Cet, config_yaml)
    S = UltraE_score(distU, bh, bt, margin)
    # CdistU has one row per triple and one column per corrupted sample;
    # transpose so axis 0 enumerates the samples.
    CS = UltraE_score(CdistU.t(), Cbh, Cbt, margin)

    # logsigmoid(x) == log(sigmoid(x)) and logsigmoid(-x) == log(1 - sigmoid(x)),
    # but it stays finite where sigmoid saturates to exactly 0 or 1.
    log_sigmoid = torch.nn.functional.logsigmoid
    fval = -1 / n * (log_sigmoid(S) + torch.sum(log_sigmoid(-CS), dim=0))
    fval[torch.isnan(fval)] = 0
    return torch.sum(fval)

def dproj(p, beta, X):
    """Rescale the trailing columns of ``X`` while keeping the first ``p``.

    The columns ``p:`` of every row are first normalised to Euclidean length
    ``beta`` (signed), then multiplied by
    ``sqrt(|beta| + ||X[:, :p]||^2) / beta``. The first ``p`` columns are
    returned unchanged.

    Args:
        p: Number of leading columns that are left untouched.
        beta: Scalar used both as target length and in the rescaling factor.
        X: 2-D tensor, one point per row.

    Returns:
        Tensor with the same shape as ``X``. Rows whose trailing part has
        zero norm divide by zero.
    """
    lead = X[:, :p]
    tail = X[:, p:]

    # Step 1: bring the trailing part to length `beta`.
    tail_norm = torch.linalg.norm(tail, ord=2, dim=-1, keepdim=True)
    unit_tail = beta * tail / tail_norm

    # Step 2: stretch it by a factor that depends on the leading part.
    lead_norm = torch.linalg.norm(lead, ord=2, dim=-1, keepdim=True)
    stretch = torch.sqrt(abs(beta) + lead_norm ** 2) / beta

    return torch.cat([lead, stretch * unit_tail], dim=1)

def UltraE_get_all_distU(eh, R, et, Ceh, Cet, config_yaml):
    """Compute distances for the positive triples and all corrupted triples.

    Args:
        eh: Head embeddings passed to ``simple_dist_UltraE``.
        R: Relation matrices passed to ``simple_dist_UltraE``.
        et: Tail embeddings passed to ``simple_dist_UltraE``.
        Ceh: 3-D tensor of corrupted head embeddings; axis 1 enumerates the
            ``k`` corrupted samples and ``Ceh[:, kk, :]`` is used as ``eh``.
        Cet: 3-D tensor of corrupted tail embeddings, laid out like ``Ceh``.
        config_yaml: Mapping providing ``["datafeature"]["p"]`` and
            ``["datafeature"]["beta"]``.

    Returns:
        ``(distU, CdistU)``: the positive distances and a 2-D tensor whose
        column ``kk`` holds the distances of corrupted sample ``kk``.
    """
    p = config_yaml["datafeature"]["p"]
    beta = config_yaml["datafeature"]["beta"]
    distU = simple_dist_UltraE(eh, R, et, p, beta)

    k = Ceh.shape[1]
    if k == 0:
        # Nothing to stack; keep the (rows, 0) shape of an empty result.
        return distU, distU.new_zeros((distU.shape[0], 0))

    # Stacking the per-sample results keeps their dtype and device, instead
    # of copying them into a default-dtype buffer allocated on the CPU.
    CdistU = torch.stack(
        [simple_dist_UltraE(Ceh[:, kk, :], R, Cet[:, kk, :], p, beta) for kk in range(k)],
        dim=1,
    )
    return distU, CdistU

def simple_dist_UltraE(eh, R, et, p, beta):
    """Return the element-wise smaller of the two ``dist_UltraE`` distances.

    Each head vector is first multiplied by its relation matrix: row ``i`` of
    the transformed heads is ``R[i] @ eh[:, i]``. The transformed heads and
    ``et`` are handed to ``dist_UltraE``, and the minimum of its two results
    is taken per triple.
    """
    # (n, 1, d) so that it broadcasts against the rows of every R[i].
    head_rows = eh.transpose(0, 1).unsqueeze(1)
    Reh = (R * head_rows).sum(dim=2)

    candidates = dist_UltraE(Reh, et, p, beta)
    stacked = torch.vstack(list(candidates))
    return stacked.min(dim=0).values

def dist_UltraE(Reh, et, p, beta):
    """Return the pair ``(dist13(...), dist14(...))`` for the same arguments."""
    return dist13(Reh, et, p, beta), dist14(Reh, et, p, beta)

def dist13(x, y, p, beta):
    """Return ``Sdist(y.t(), m) + Sdist(m, x)`` with ``m = rhob_a(y, x, p).t()``.

    Args:
        x: 2-D tensor, one point per row.
        y: 2-D tensor, one point per column (it is transposed before use).
        p: Split index forwarded to ``rhob_a`` and ``Sdist``.
        beta: Scalar forwarded to ``Sdist``.

    Returns:
        The sum of the two ``Sdist`` results.
    """
    # The intermediate point is needed by both terms; build it once instead
    # of evaluating rhob_a twice.
    midpoint = rhob_a(y, x, p).t()
    return Sdist(y.t(), midpoint, beta, p) + Sdist(midpoint, x, beta, p)


def dist14(x, y, p, beta):
    """Return ``Sdist(x, m) + Sdist(m, y.t())`` with ``m = rhob_a(x.t(), y.t(), p).t()``.

    Args:
        x: 2-D tensor, one point per row.
        y: 2-D tensor, one point per column (it is transposed before use).
        p: Split index forwarded to ``rhob_a`` and ``Sdist``.
        beta: Scalar forwarded to ``Sdist``.

    Returns:
        The sum of the two ``Sdist`` results.
    """
    # The intermediate point is needed by both terms; build it once instead
    # of evaluating rhob_a twice.
    midpoint = rhob_a(x.t(), y.t(), p).t()
    return Sdist(x, midpoint, beta, p) + Sdist(midpoint, y.t(), beta, p)

def rhob_a(a, b, p):
    """Combine the first ``p`` rows of ``a`` with rescaled trailing rows of ``b``.

    Args:
        a: 2-D tensor with one point per column (shape ``(d, n)``).
        b: 2-D tensor with one point per row (shape ``(n, d)``).
        p: Number of leading coordinates taken from ``a``.

    Returns:
        Tensor of shape ``(d, n)``. Rows ``:p`` are copied from ``a``; rows
        ``p:`` are the matching coordinates of ``b`` multiplied, per point, by
        ``||a[:p]|| / ||b[:p]||``. A point of ``b`` whose first ``p``
        coordinates are all zero divides by zero.
    """
    norma_p = torch.linalg.norm(a[0:p, :], ord=2, dim=0)
    normb_p = torch.linalg.norm(b[:, 0:p], ord=2, dim=1)
    # No squeeze here: the per-point norms already broadcast along the last
    # axis, and squeezing dropped an axis (breaking the cat below) whenever
    # there was a single trailing row or a single point.
    rescaled_tail = b.t()[p:, :] * norma_p / normb_p
    return torch.cat([a[:p, :], rescaled_tail], dim=0)

def Sdist(A, B, beta, p):
    """Return a row-wise distance derived from ``qdot(A, B.t(), p) / beta``.

    With ``q`` denoting that quotient, each entry of the result is
    ``sqrt(|beta|) * acos(|q|)`` where ``|q| < 1`` and
    ``sqrt(|beta|) * acosh(|q|)`` everywhere else (NaN entries of ``q`` fall
    into the second case and stay NaN).

    Args:
        A: 2-D tensor, one point per row.
        B: 2-D tensor with the same shape as ``A``.
        beta: Scalar (number or tensor).
        p: Split index forwarded to ``qdot``.

    Returns:
        1-D tensor with one value per row, on the device and with the dtype
        of the ``qdot`` quotient.
    """
    # as_tensor accepts plain numbers and existing tensors alike, without the
    # copy-construct warning (and graph detach) torch.tensor gives for tensors.
    beta = torch.as_tensor(beta, device=A.device)
    temp = qdot(A, B.t(), p) / beta
    magnitude = torch.abs(temp)
    scale = torch.sqrt(torch.abs(beta))

    mask = magnitude < 1
    # zeros_like keeps dtype and device aligned with the values written
    # below; a default-dtype buffer does not match e.g. float64 inputs.
    y = torch.zeros_like(temp)
    y[mask] = scale * torch.acos(magnitude[mask])
    y[~mask] = scale * torch.acosh(magnitude[~mask])
    return y

def qdot(A, B, p):
    """Row-wise indefinite inner product with ``p`` positive coordinates.

    ``A`` holds one point per row and ``B`` one point per column. For every
    row ``i`` the products ``A[i, j] * B[j, i]`` are summed with a plus sign
    for ``j < p`` and a minus sign for ``j >= p``.

    Returns:
        1-D tensor with one value per row of ``A``.
    """
    products = A * B.t()
    leading_sum = products[:, :p].sum(dim=1)
    trailing_sum = products[:, p:].sum(dim=1)
    return leading_sum - trailing_sum

def UltraE_score(distU, bh, bt, margin):
    """Return ``-distU**2 + bh + bt + margin``.

    Args:
        distU: Distances.
        bh: Bias added for the head.
        bt: Bias added for the tail.
        margin: Constant offset.

    Returns:
        The score, broadcast over the shapes of the inputs.
    """
    score = bh - distU ** 2
    score = score + bt
    return score + margin

def compute_distance_matrix(A, B, beta, p):
    """Map ``qdot(A, B.t(), p) / |beta|**2`` to a distance, case by case.

    With ``K`` the scaled ``qdot`` value and ``eps = 1e-5`` the result is

    * ``beta * acosh(-K)``        where ``K < -1 - eps``,
    * ``beta * |2 * (1 + K)|``    where ``-1 - eps <= K < -1 + eps``,
    * ``beta * acos(-K)``         where ``-1 + eps <= K < 0`` (NaN entries
      also land here),
    * ``beta * (pi / 2 + K)``     where ``K >= 0``.

    Returns:
        The tensor of scaled ``qdot`` values, overwritten with the distances.
    """
    epsilon = 0.00001
    K = qdot(A, B.t(), p) / abs(beta) ** 2

    # All masks are taken from the untouched K before any entry is replaced.
    below_upper_band = K < -1.0 + epsilon
    in_hyperbolic = K < -1.0 - epsilon
    in_euclidean = below_upper_band & (~in_hyperbolic)
    is_positive = K >= 0.0
    in_spherical = ~(is_positive | below_upper_band)

    K[in_hyperbolic] = beta * torch.acosh(-K[in_hyperbolic])
    K[in_euclidean] = beta * torch.abs(2.0 * (1.0 + K[in_euclidean]))
    K[is_positive] = beta * (math.pi / 2 + K[is_positive])
    K[in_spherical] = beta * torch.acos(-K[in_spherical])
    return K

def Utest(X, p, c):
    """Return the summed deviation of ``qdot(X, X.t(), p)`` from ``c``.

    Every row of ``X`` contributes ``qdot`` of the row with itself minus
    ``c``; the contributions are added up into a single value.
    """
    self_products = qdot(X, X.t(), p)
    deviation = self_products - c
    return deviation.sum()